import importlib
import json
import os
import time
from datetime import datetime, timedelta

import backtrader as bt
import click
from backtrader.utils import logger, datehelper
from backtrader.utils.cmd import cli
import backtrader.stores as storemgt


class LiveRun:
    """Configure and launch a live backtrader run for one strategy.

    All settings come from keyword arguments (presumably parsed CLI options —
    confirm against the caller). ``run`` builds a ``bt.Cerebro`` with the
    strategy, an exchange store/broker and one data feed per
    (timeframe, symbol) pair, sends a start-up alert and then calls
    ``cerebro.run``; if that fails, a crash alert is sent and the exception
    is re-raised.
    """

    # Argument / strategy-parameter keys whose values are credentials and
    # must never be written to the logs verbatim.
    _SECRET_KEYS = frozenset({"alert_token", "token", "apiKey", "secret", "password"})

    def __init__(self, **kwargs):
        """Read the run configuration from *kwargs*.

        Required keys: account, alert_token, alert_enable, debug, strat,
        symbols, type, timeframe, resample, exchange, params, max_amount,
        max_size, ordtype, slippage, lookback, leverage, max_trade.
        ``alert_enable`` and ``debug`` are strings compared case-insensitively
        against "TRUE".

        Raises:
            KeyError: if a required key is missing.
            ValueError: if ``lookback`` and ``timeframe`` differ in length
                (one lookback is needed per timeframe).
        """
        self.logger = logger.getLogger(__name__)
        self.account = bt.AccountMgt().get_account(kwargs["account"])
        self.logger.info("%s Input Args %s", "*" * 15, "*" * 15)
        self.print_params(kwargs)

        token = kwargs["alert_token"]
        alert_enable = kwargs["alert_enable"].upper() == "TRUE"
        self.debug = kwargs["debug"].upper() == "TRUE"

        self.alert = bt.utils.Alert(token, alert_enable)
        self.strat_code = kwargs["strat"]
        self.symbols = kwargs["symbols"]
        # Base currency of each pair, e.g. "BTC/USDT" -> "BTC".
        self.currency = [symbol.replace("/USDT", "") for symbol in self.symbols]
        self.asset_type = kwargs["type"]
        self.time_frames = kwargs["timeframe"]
        self.resample = kwargs["resample"]
        self.exchange = kwargs["exchange"]
        self.params = kwargs["params"]
        self.max_amount = kwargs["max_amount"]
        self.max_size = kwargs["max_size"]
        self.ordtype = kwargs["ordtype"]
        self.slippage = kwargs["slippage"]
        self.lookbacks = kwargs["lookback"]
        self.leverage = kwargs["leverage"]
        self.max_trade = kwargs["max_trade"]
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(self.lookbacks) != len(self.time_frames):
            raise ValueError(
                "lookback count (%d) must match timeframe count (%d)"
                % (len(self.lookbacks), len(self.time_frames))
            )

    @classmethod
    def _mask_secrets(cls, mapping):
        """Return a copy of *mapping* with non-empty credential values replaced by "***"."""
        return {
            key: ("***" if key in cls._SECRET_KEYS and value else value)
            for key, value in mapping.items()
        }

    def print_strategy_config(self):
        """Log every strategy parameter as ``Strategy: key=value``."""
        for key, value in self.params.items():
            self.logger.info("Strategy: %s=%s", key, value)

    def print_params(self, kwargs):
        """Log the raw input arguments, masking credential values such as the alert token."""
        self.logger.info(self._mask_secrets(kwargs))

    def format_symbol(self, symbol, timeframe):
        """Return ``(dataname, name)`` for a feed.

        ``dataname`` is the symbol used with the broker/exchange; ``name``
        distinguishes feeds, e.g. ("BTC/USDT", "1h") -> ("BTC/USDT", "BTC_USDT_1h").
        """
        return symbol, symbol.replace("/", "_") + f"_{timeframe}"

    def _log_run_config(self):
        """Log the program and strategy configuration before start-up."""
        self.logger.info("%s Program Args %s", "*" * 15, "*" * 15)
        self.logger.info("Start: exchange=%s", self.exchange)
        self.logger.info("Start: @@@@【Type=%s】@@@@", self.asset_type)
        self.logger.info("Start: @@@@【account=%s】@@@@", self.account.name)
        self.logger.info("Start: @@@@【max amount=%s】@@@@", self.max_amount)
        self.logger.info("Start: @@@@【max size=%s】@@@@", self.max_size)
        self.logger.info("Start: @@@@【max ordtype=%s】@@@@", self.ordtype)
        self.logger.info("Start: @@@@【max slippage=%s】@@@@", self.slippage)
        self.logger.info("Start: @@@@【leverage=%s】@@@@", self.leverage)
        self.logger.info("Start: symbol=%s", ",".join(self.symbols))
        self.logger.info("Start: timeframe=%s", ",".join(self.time_frames))
        self.logger.info("%s", "*" * 40)

        self.logger.info("%s Strategy Args %s", "*" * 15, "*" * 15)
        self.print_strategy_config()

    def _load_strategy_class(self, module_name):
        """Import *module_name* and return ``(class_name, class)`` of its strategy.

        Candidates are module attributes whose name ends with "Strategy" and
        that are classes. Classes defined in the module itself are preferred
        over imported ones (e.g. ``from backtrader import Strategy``), because
        imports appear first in the module namespace and would otherwise win.

        Raises:
            Exception: if the module has no such class.
        """
        # import_module returns the leaf module for dotted names, unlike __import__.
        module = importlib.import_module(module_name)
        candidates = [
            (name, obj)
            for name, obj in vars(module).items()
            if name.endswith("Strategy") and isinstance(obj, type)
        ]
        own = [item for item in candidates if item[1].__module__ == module.__name__]
        chosen = own or candidates
        if not chosen:
            # "Strategy class name must end with 'Strategy'."
            raise Exception("策略名称必须以Strategy结尾.")
        return chosen[0]

    def _build_init_params(self, strategy_name):
        """Return the strategy parameters merged with the run-level settings."""
        init_params = self.params.copy()
        init_params.update(
            name=strategy_name,
            asset_type=self.asset_type,
            # NOTE(review): name[0] is used here while the full name is logged
            # elsewhere — presumably `name` is a sequence; confirm it is not a str.
            account=self.account.name[0],
            max_amount=self.max_amount,
            max_size=self.max_size,
            ordtype=self.ordtype,
            slippage=self.slippage,
            leverage=self.leverage,
            token=self.alert.token,
            max_trade=self.max_trade,
            enable=self.alert.enable,
            debug=self.debug,
        )
        return init_params

    def _build_exchange_config(self, store_cls):
        """Return a per-run copy of ``store_cls.exchange_config`` with credentials and market type.

        The nested ``options`` dict is copied too: a shallow ``.copy()`` alone
        would share it, and setting ``defaultType`` would leak into the
        class-level configuration.
        """
        config = store_cls.exchange_config.copy()
        config.update({
            "apiKey": self.account.apiKey,
            "secret": self.account.secret,
            "password": self.account.password,
        })
        config["options"] = dict(config.get("options") or {})
        config["options"]["defaultType"] = self.asset_type
        return config

    def _wait_for_exchange(self, store, retry_seconds=60):
        """Block until a small BTC/USDT 1h OHLCV request succeeds.

        This is a connectivity probe. It retries indefinitely, sleeping
        *retry_seconds* between attempts.
        """
        while True:
            try:
                store.fetch_ohlcv(symbol="BTC/USDT", timeframe="1h", since=None, limit=10)
                return
            except Exception as e:
                self.logger.info("Exchange probe failed, retrying in %ss: %s", retry_seconds, e)
                time.sleep(retry_seconds)

    def _create_broker(self, store, store_cls):
        """Return a spot broker for asset type SPOT, otherwise an isolated, one-way futures broker."""
        if self.asset_type.upper() == "SPOT":
            return store.getbroker(broker_mapping=store_cls.broker_mapping)
        mapping = store_cls.broker_future_mapping.copy()
        mapping["future_config"] = {
            "leverage": self.leverage,
            "dualSidePosition": False,
            "isolated": True,
        }
        return store.getfuturebroker(broker_mapping=mapping)

    def _add_data_feeds(self, cerebro, store):
        """Add one live feed per (timeframe, symbol), preloaded with ``lookback + 1`` bars."""
        for tf, lookback in zip(self.time_frames, self.lookbacks):
            # Timeframe parsing is symbol-independent, so do it once per timeframe.
            timeframe, compression, total_minutes = datehelper.parse_timeframe(tf)
            ohlcv_limit = lookback + 1
            hist_start_date = datetime.utcnow() - timedelta(minutes=ohlcv_limit * total_minutes)
            for symbol in self.symbols:
                # dataname: name used when talking to the broker; name: feed identifier.
                dataname, name = self.format_symbol(symbol, tf)
                data = store.getdata(dataname=dataname, name=name,
                                     timeframe=timeframe, fromdate=hist_start_date,
                                     compression=compression, ohlcv_limit=ohlcv_limit,
                                     drop_newest=True)
                cerebro.adddata(data)

    def _send_alert(self, strategy_name, action, detail):
        """Send a formatted alert for this run, stamped with the current local time."""
        self.alert.send(self.alert.format().format(name=strategy_name,
                                                   account=self.account.name[0],
                                                   asset=",".join(self.symbols),
                                                   date=datehelper.date2str(datetime.now()),
                                                   action=action,
                                                   detail=detail))

    def run(self):
        """Assemble the Cerebro engine for this configuration and run it.

        Steps:
          1. Log the configuration.
          2. Load the strategy class.
          3. Create the exchange store and wait until the exchange answers.
          4. Attach the spot/future broker and the data feeds.
          5. Send a "program started" alert.
          6. Call ``cerebro.run``.

        Raises:
            Exception: when the strategy module exposes no ``*Strategy`` class.
            Any exception raised by the start-up alert or ``cerebro.run`` is
            re-raised unchanged after a "program crashed" alert is sent.
        """
        self._log_run_config()

        cerebro = bt.Cerebro(quicknotify=True)

        strategy_name = self.strat_code
        strat_clz, strat_cls = self._load_strategy_class(self.strat_code)
        init_params = self._build_init_params(strategy_name)

        self.logger.info("Load Strategy: class=%s", strat_clz)
        self.logger.info("Load Strategy: name=%s", strategy_name)
        # init_params carries the alert token; mask it before logging.
        self.logger.info("Load Strategy: param=%s", self._mask_secrets(init_params))

        cerebro.addstrategy(strat_cls, **init_params)

        # Resolve the exchange's store class dynamically, e.g. "okex" -> storemgt.OkexStore.
        store_cls = getattr(storemgt, f"{self.exchange.capitalize()}Store")
        store = store_cls(exchange=self.exchange, currency=self.currency,
                          config=self._build_exchange_config(store_cls), retries=5)
        self._wait_for_exchange(store)

        cerebro.setbroker(self._create_broker(store, store_cls))
        self._add_data_feeds(cerebro, store)

        if self.resample:
            # The resampledata() call is disabled (an earlier note said it needs
            # bar2edge=False); report that instead of silently ignoring the option.
            self.logger.info("resample=%s is set but resampling is disabled; ignoring it",
                             self.resample)

        try:
            detail = ("\nmax_amount: %s\ntime_frame: %s\n"
                      "max_size: %s\nordtype: %s\nslippage: %s\n策略参数: %s"
                      % (self.max_amount, ",".join(self.time_frames),
                         self.max_size, self.ordtype, self.slippage, self.params))
            # "程序启动" = program started.
            self._send_alert(strategy_name, "程序启动", detail)

            # `runonce` (the original `runonece` typo was silently ignored by Cerebro.run).
            cerebro.run(runonce=False)
        except Exception as e:
            # "程序崩溃" = program crashed.
            self._send_alert(strategy_name, "程序崩溃", "%s: %s" % (e.__class__.__name__, e))
            raise